import copy
import re
import time
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

from rich.console import Console
from rich.markup import escape

from vision_agent.agent import Agent
from vision_agent.agent.vision_agent_prompts_v3 import get_init_prompt
from vision_agent.configs import Config
from vision_agent.lmm import LMM, AnthropicLMM
from vision_agent.models import AgentMessage, Message
from vision_agent.utils.agent import (
    add_media_to_chat,
    capture_media_from_exec,
    convert_message_to_agentmessage,
    extract_tag,
    print_code,
    remove_installs_from_code,
)
from vision_agent.utils.execute import CodeInterpreter, CodeInterpreterFactory

# Module-level Config instance.
CONFIG = Config()
# Upper bound on how many execution-result images run_code attaches to a
# single observation message.
MAX_IMAGES = 10
# Rich console used by VisionAgentV3.chat for step and execution logging.
_CONSOLE = Console()


class DefaultImports:
    """Header of import/setup statements prepended to every executed code cell.

    ``imports`` holds one source line per entry; ``prepend_imports`` joins them
    with newlines and separates the header from the user's code by a blank line.
    """

    imports = [
        "import os",
        "import numpy as np",
        "import cv2",
        "from typing import *",
        "from pillow_heif import register_heif_opener",
        "from vision_agent.tools import load_image",
        (
            "from vision_agent.tools.planner_v3_tools import "
            "instance_segmentation, ocr, depth_estimation, "
            "visualize_bounding_boxes, visualize_segmentation_masks, "
            "get_crops, rotate_90, display_image, iou"
        ),
        "register_heif_opener()",
        "import matplotlib.pyplot as plt",
    ]

    @staticmethod
    def prepend_imports(code: str) -> str:
        """Return ``code`` preceded by the import header and one blank line."""
        header = "\n".join(DefaultImports.imports)
        return f"{header}\n\n{code}"


def run_chat(
    model: LMM,
    chat: List[AgentMessage],
    kwargs: Optional[Dict[str, Any]] = None,
) -> str:
    """Send ``chat`` to ``model`` as a plain user/assistant transcript.

    Messages whose role is ``user``, ``observation``, ``final_observation`` or
    ``error_observation`` are sent with role ``user``; every other role is sent
    as ``assistant``. The messages are deep-copied first, so the model never
    receives the caller's own content/media objects.

    Args:
        model: LMM invoked with the formatted message list.
        chat: Conversation to send.
        kwargs: Extra keyword arguments forwarded to the model call.

    Returns:
        The model's response, cast to ``str``.
    """
    user_side_roles = ("user", "observation", "final_observation", "error_observation")
    messages = [
        {
            "role": "user" if message.role in user_side_roles else "assistant",
            "content": message.content,
            "media": message.media,
        }
        for message in copy.deepcopy(chat)
    ]
    return cast(str, model(messages, **(kwargs or {})))  # type: ignore


def strip_signature(response: str) -> str:
    """Remove the ``<signature>...</signature>`` block from ``response``.

    The inner text is located with ``extract_tag``. Every occurrence of the
    exact tagged block is then removed with ``str.replace``. If no signature
    is found, ``response`` is returned unchanged.
    """
    found = extract_tag(response, "signature")
    if found is None:
        return response
    return response.replace(f"<signature>{found}</signature>", "")


def strip_signature_from_agentmessage(
    response: AgentMessage,
) -> AgentMessage:
    """Return a new message whose content has its signature block stripped.

    Only ``role``, ``content`` (after ``strip_signature``) and ``media`` are
    carried over to the returned message; the input message is not modified.
    """
    cleaned = strip_signature(response.content)
    return AgentMessage(role=response.role, content=cleaned, media=response.media)


def fix_xml_code_tags(response: str) -> str:
    """Normalize code/answer markup in a model response.

    1. The first Markdown fenced block opened with ```` ```python ```` and closed
       with ```` ``` ```` is rewritten as ``<code>\\n...\\n</code>``, with its
       contents stripped of surrounding whitespace. An unclosed fence is left as is.
    2. If ``<answer>`` appears without any ``</answer>``, one is appended.
    3. Then, if ``<code>`` appears without any ``</code>``, one is appended.
    """
    fence_open, fence_close = "```python", "```"

    open_at = response.find(fence_open)
    if open_at != -1:
        body_at = open_at + len(fence_open)
        close_at = response.find(fence_close, body_at)
        if close_at != -1:
            snippet = response[body_at:close_at].strip()
            before = response[:open_at]
            after = response[close_at + len(fence_close) :]
            response = f"{before}<code>\n{snippet}\n</code>{after}"

    # The answer check runs before the code check, and each sees the result
    # of the previous step.
    for opening, closing in (("<answer>", "</answer>"), ("<code>", "</code>")):
        if opening in response and closing not in response:
            response += closing
    return response


def strip_extra_content(response: str) -> str:
    """Reduce a response containing code to its thinking, signature and code.

    When ``response`` contains ``<code>``, the result is the concatenation of
    three parts, in this order:

    * the first ``<thinking>...</thinking>`` block;
    * the first ``<signature>...</signature>`` block;
    * the first ``<code>...</code>`` block.

    All other text, including any ``<answer>`` block, is dropped. A thinking or
    signature block that is missing or unterminated is omitted. An unterminated
    code block extends to the end of ``response``. Responses without ``<code>``
    are returned unchanged.
    """

    def _tagged_span(open_tag: str, close_tag: str) -> str:
        # First open_tag...close_tag slice, or "" when either tag is absent.
        # Checking both indices avoids slicing with -1, which produced
        # fragments such as a lone "<thinking>".
        begin = response.find(open_tag)
        if begin == -1:
            return ""
        end = response.find(close_tag, begin)
        if end == -1:
            return ""
        return response[begin : end + len(close_tag)]

    code_start = response.find("<code>")
    if code_start == -1:
        return response

    code_end = response.find("</code>", code_start)
    if code_end == -1:
        # Keep the unterminated code rather than silently discarding it.
        code = response[code_start:]
    else:
        code = response[code_start : code_end + len("</code>")]

    return (
        _tagged_span("<thinking>", "</thinking>")
        + _tagged_span("<signature>", "</signature>")
        + code
    )


def run_code(
    code: str,
    code_interpreter: CodeInterpreter,
) -> Tuple[str, List[str], float]:
    """Execute ``code`` in ``code_interpreter`` and collect its output.

    Install commands are removed via ``remove_installs_from_code``, and the
    ``DefaultImports`` header is prepended before the cell is executed.

    Args:
        code: Code produced by the model.
        code_interpreter: Interpreter the cell is executed in.

    Returns:
        A tuple ``(obs, images, latency)``:

        * ``obs`` is the execution text including logs, stripped. When images
          were captured, a note describing them is appended.
        * ``images`` holds at most ``MAX_IMAGES`` items from
          ``capture_media_from_exec``, or ``[]``.
        * ``latency`` is the elapsed execution time in seconds.
    """
    code = remove_installs_from_code(code)
    # perf_counter is monotonic and meant for measuring durations.
    start = time.perf_counter()
    execution = code_interpreter.exec_cell(DefaultImports.prepend_imports(code))
    latency = time.perf_counter() - start

    obs = execution.text(include_logs=True).strip()
    result_images = capture_media_from_exec(execution)
    if not result_images:
        return obs, [], latency

    return_images = result_images[:MAX_IMAGES]
    if len(return_images) < len(result_images):
        # Tell the model the true count, not just how many images were attached.
        image_note = (
            f"\n\n[{len(result_images)} images were generated by your code; "
            f"the first {len(return_images)} are included with this message]"
        )
    else:
        image_note = f"\n\n[{len(return_images)} images were generated by your code and are included with this message]"
    obs += image_note
    return obs, return_images, latency


def format_obs_message(
    obs: str,
    turn: int,
    turns: int,
) -> str:
    """Wrap a code-execution result in the observation text shown to the model.

    Args:
        obs: Execution output, e.g. as returned by ``run_code``.
        turn: Zero-based index of the turn that produced ``obs``.
        turns: Total number of turns available.

    Returns:
        ``obs`` prefixed with a one-based ``[Turn i/N]`` header. When the next
        turn is the last one (``turn == turns - 2``), a warning asking for a
        final ``<answer>`` is appended.
    """
    obs_message = f"[Turn {turn + 1}/{turns}] Code execution result:\n{obs}"
    if turn == turns - 2:
        warning_msg = (
            "\n\n⚠️CRITICAL: The next turn will be your FINAL turn. "
            "Please make sure to provide your final answer in <answer> tags "
            "in your next response, no need to include <code> tags. "
            "Remember to print out final answers without any explanation, "
            "it could be a single word, number, price or a list of bounding "
            "boxes of object detection."
        )
        obs_message += warning_msg
    return obs_message


class VisionAgentV3(Agent):
    """Agent that answers visual questions by alternating LMM calls with code runs.

    On each turn the LMM either emits a ``<code>`` block or an ``<answer>``
    block. Code is executed in a code interpreter and fed back as an
    observation; an answer ends the loop. At most ``self.turns`` LMM calls are
    made per ``chat``.
    """

    def __init__(
        self,
        agent: Optional[LMM] = None,
        hil: bool = False,
        verbose: bool = False,
        code_sandbox_runtime: Optional[str] = None,
        update_callback: Callable[[Dict[str, Any]], None] = lambda x: None,
    ) -> None:
        """Create the agent.

        Args:
            agent: LMM used for every turn. Defaults to an ``AnthropicLMM``
                running ``claude-sonnet-4-20250514``.
            hil: Accepted for interface compatibility; not used by this class.
            verbose: When True, print each step, the generated code and its
                execution result to the console.
            code_sandbox_runtime: Runtime passed to
                ``CodeInterpreterFactory.new_instance`` when ``chat`` has to
                create its own interpreter.
            update_callback: Called with a ``model_dump()`` dict for each
                message produced during ``chat``.
        """
        # Use the caller's LMM when given. Previously a supplied agent was
        # ignored and self.agent was never set.
        self.agent: LMM = (
            agent
            if agent is not None
            else AnthropicLMM(model_name="claude-sonnet-4-20250514", max_tokens=8192)
        )
        # Forwarded to every LMM call. The stop sequences end generation right
        # after a code block or an answer, so each turn yields one action;
        # fix_xml_code_tags appends a closing tag if the response lacks one.
        # NOTE(review): "thinking" is an Anthropic-style option — confirm a
        # custom ``agent`` accepts these kwargs.
        self.kwargs: Dict[str, Any] = {
            "thinking": {"type": "enabled", "budget_tokens": 4096},
            "stop_sequences": ["</code>", "</answer>"],
        }

        self.turns = 7
        self.verbose = verbose
        self.code_sandbox_runtime = code_sandbox_runtime
        self.update_callback = update_callback

    def __call__(
        self,
        input: Union[str, List[Message]],
        media: Optional[Union[str, Path]] = None,
    ) -> str:
        """Run ``chat`` on ``input``/``media`` and return the last message's content."""
        msg = convert_message_to_agentmessage(input, media)
        return self.chat(msg)[-1].content

    def chat(
        self,
        chat: List[AgentMessage],
        code_interpreter: Optional[CodeInterpreter] = None,
    ) -> List[AgentMessage]:
        """Run the code/answer loop for a conversation.

        Args:
            chat: Conversation so far. It must be non-empty and end with a
                ``user`` or ``interaction_response`` message. Only the first
                message (after ``add_media_to_chat``) supplies the question and
                media for the initial prompt.
            code_interpreter: Interpreter to execute code in. When None, one is
                created from ``self.code_sandbox_runtime``. In either case it
                is used as a context manager and exited when the loop ends.

        Returns:
            The agent transcript: the initial prompt, each assistant response,
            and each code-execution observation. The loop stops early once the
            model produces an ``<answer>``, so that answer is the last message.

        Raises:
            ValueError: If ``chat`` is empty or its last message has another role.
        """
        chat = copy.deepcopy(chat)
        # Check emptiness first so building the error message never indexes
        # an empty list.
        if not chat:
            raise ValueError(
                "Last chat message must be from the user or interaction_response, got an empty chat."
            )
        if chat[-1].role not in {"user", "interaction_response"}:
            raise ValueError(
                f"Last chat message must be from the user or interaction_response, got {chat[-1].role}."
            )

        return_chat = []
        with (
            CodeInterpreterFactory.new_instance(self.code_sandbox_runtime)
            if code_interpreter is None
            else code_interpreter
        ) as code_interpreter:
            int_chat, _, _ = add_media_to_chat(
                chat, code_interpreter, append_to_prompt=False
            )
            # NOTE(review): str() of the media field is passed as image_path;
            # confirm get_init_prompt expects that representation.
            init_prompt = get_init_prompt(
                model="",
                turns=self.turns,
                question=int_chat[0].content,
                category="",
                image_path=str(int_chat[0].media),
            )
            return_chat.append(
                AgentMessage(role="user", content=init_prompt, media=int_chat[0].media)
            )

            for turn in range(self.turns):
                response = run_chat(self.agent, return_chat, self.kwargs)
                response = fix_xml_code_tags(response)
                response = strip_extra_content(response)

                return_chat.append(AgentMessage(role="assistant", content=response))
                self.update_callback(
                    strip_signature_from_agentmessage(return_chat[-1]).model_dump()
                )

                code = extract_tag(response, "code")
                thoughts = extract_tag(response, "thinking")
                answer = extract_tag(response, "answer")

                if self.verbose:
                    # One-based step number, matching format_obs_message.
                    # Escape model text so stray "[...]" is not parsed as
                    # rich markup.
                    _CONSOLE.print(
                        f"[bold cyan]Step {turn + 1}/{self.turns}[/bold cyan]\n"
                        f"[green]{escape(str(thoughts))}[/green]\n"
                    )
                    if answer is not None:
                        _CONSOLE.print(
                            f"[magenta]Final answer: {escape(answer)}[/magenta]\n"
                        )
                    if code is not None:
                        print_code("Code:", code)

                if answer is not None:
                    # The answer is already in the assistant message above, so
                    # it is only reported to the callback, not appended again.
                    self.update_callback(
                        AgentMessage(
                            role="final_observation",
                            content=f"<answer>{answer}</answer>",
                        ).model_dump()
                    )
                    # Stop once the model commits to an answer. Continuing
                    # would re-query the model with an assistant message as
                    # the last turn.
                    break

                if code is not None:
                    obs, images, latency = run_code(code, code_interpreter)
                    obs = format_obs_message(obs, turn, self.turns)
                    if self.verbose:
                        _CONSOLE.print(
                            f"[bold cyan]Code execution took {latency:.2f} seconds.[/bold cyan]\n"
                            f"[yellow]{escape(obs)}[/yellow]\n"
                        )
                    return_chat.append(
                        AgentMessage(role="observation", content=obs, media=images)
                    )
                    self.update_callback(
                        strip_signature_from_agentmessage(return_chat[-1]).model_dump()
                    )
                # NOTE(review): when a response has neither <code> nor
                # <answer>, nothing is appended, so the next turn re-queries
                # with an assistant message last. Confirm this is intended.
        return return_chat

    def log_progress(self, data: Dict[str, Any]) -> None:
        """No-op; this agent does not record progress here."""
        pass
